// STL
#include <fstream>   // std::ifstream, std::ofstream
#include <stdexcept> // std::invalid_argument, std::runtime_error
#include <string>    // std::string, std::to_string
#include <vector>    // std::vector<>

// Boost
#include <boost/archive/text_oarchive.hpp> // boost::archive::text_oarchive

// jScience
#include "jScience/linalg.hpp" // Matrix<>

// stats++
#include "statsxx/machine_learning/DBN.hpp" // machine_learning::DBN


//
// DESC: Trains a DBN.
//
//
// DESC: Trains a DBN on X and saves the trained network to dbn_file.
//
// NOTES:
//     - dbn is taken by value; the trained copy is returned, and the caller's
//       object is not modified.
//     - X is used as BOTH the training and the validation set.
//     - dbn_file is passed to DBN::train() as a prefix, and it is also the
//       path the final network is serialized to (Boost text archive). What
//       DBN::train() itself writes under that prefix is not visible here.
//     - The number of mini-batches is X.size(0)/train_npts_per_batch, rounded
//       down. If X has fewer rows than train_npts_per_batch, this is 0.
//       NOTE(review): confirm DBN::train() handles a zero batch count sensibly.
//     - The free-energy monitoring settings are fixed here: npts =
//       train_npts_per_batch, nepoch = 5, window = 100.
//
// THROWS:
//     - std::invalid_argument if train_npts_per_batch <= 0.
//     - std::runtime_error if dbn_file cannot be opened for writing.
//
machine_learning::DBN train_DBN(
                                machine_learning::DBN        dbn,
                                // =====
                                const Matrix<double>        &X,
                                // =====
                                const int                    train_nepoch,
                                const int                    train_npts_per_batch,
                                // -----
                                const int                    train_K_start,
                                const int                    train_K_end,
                                const double                 train_K_rate,
                                // -----
                                const bool                   train_sample_v,
                                const bool                   train_sample_hdata,
                                // -----
                                const double                 train_lr,
                                const std::vector<double>    train_lr_dot,
                                const std::vector<double>    train_lr_adj,
                                const double                 train_lr_min,
                                const double                 train_lr_max,
                                // -----
                                const double                 train_p_start,
                                const double                 train_p_end,
                                const double                 train_p_rate,
                                const std::vector<double>    train_p_dot,
                                const std::vector<double>    train_p_adj,
                                // -----
                                const double                 train_w_penalty,
                                // -----
                                //    int free_energy_npts;
                                //    int free_energy_nepoch;
                                //    int free_energy_window;
                                // -----
                                const int                    train_convg_nno_improvement,
                                const double                 train_convg_max_inc_max_w,
                                // =====
                                const std::string            dbn_file
                                )
{
    // The batch count below divides by train_npts_per_batch: zero would be
    // undefined behaviour (integer division by zero), and a negative batch
    // size is meaningless.
    if( train_npts_per_batch <= 0 )
    {
        throw std::invalid_argument(
                                    "train_DBN(): train_npts_per_batch must be positive, got "
                                    + std::to_string(train_npts_per_batch)
                                    );
    }
    
    // NOTE: the following will round down ... which is probably what we want
    int train_nbatches = X.size(0)/train_npts_per_batch;
    
    int train_free_energy_npts   = train_npts_per_batch;
    int train_free_energy_nepoch = 5;
    int train_free_energy_window = 100;
    
    dbn.train(
              train_nepoch,
              train_nbatches,
              // -----
              train_K_start,
              train_K_end,
              train_K_rate,
              // -----
              train_sample_v,
              train_sample_hdata,
              // -----
              train_lr,
              train_lr_dot,
              train_lr_adj,
              train_lr_min,
              train_lr_max,
              // ---
              train_p_start,
              train_p_end,
              train_p_rate,
              train_p_dot,
              train_p_adj,
              // -----
              train_w_penalty,
              // -----
              train_free_energy_npts,
              train_free_energy_nepoch,
              train_free_energy_window,
              // -----
              // NOTE(review): these two are passed in the opposite order to
              // train_DBN's own parameter list (nno_improvement, then
              // max_inc_max_w). Confirm against DBN::train()'s declaration: an
              // int/double swap here would compile silently via implicit
              // conversion.
              train_convg_max_inc_max_w,
              train_convg_nno_improvement,
              // -----
              X,
              X, // NOTE: X is used as both the training and validation sets
              // -----
              dbn_file // NOTE: use dbn_file as a prefix
              );
    
    // SAVE DBN TO ARCHIVE
    {
        std::ofstream ofs(dbn_file, std::ios::out);
        
        // Fail loudly (with the offending path) instead of serializing into a
        // stream that never opened.
        if( !ofs )
        {
            throw std::runtime_error(
                                     "train_DBN(): could not open '" + dbn_file + "' for writing"
                                     );
        }
        
        // Boost.Serialization expects an archive to be destroyed before its
        // stream is closed; declaring oa after ofs guarantees that ordering.
        boost::archive::text_oarchive oa(ofs);
        
        oa << dbn;
    }
        
    return dbn;    
}
